package service

import (
	"context"
	"errors"
	"fmt"
	"github.com/aliyun/alibabacloud-kms-agent/internal/conf"
	"github.com/aliyun/alibabacloud-kms-agent/internal/logger"
	"github.com/aliyun/alibabacloud-kms-agent/internal/model"
	"io/ioutil"
	"net"
	"net/http"
	"os"
	"os/signal"
	"runtime/debug"
	"strings"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/aliyun/alibabacloud-kms-agent/internal/cache"
	"github.com/aliyun/alibabacloud-kms-agent/internal/kms"
)

// Request paths with fixed routing in Server.handleRequest.
const (
	// PingPath is the health-check endpoint; validateToken lets it through
	// without an SSRF token.
	PingPath = "/ping"
	// QueryStylePath is the endpoint routed to the get-secret handler.
	QueryStylePath = "/secretsmanager/get"
)

// defaultRequestTimeout is used as both ReadTimeout and WriteTimeout of the
// http.Server built in Serve.
const defaultRequestTimeout = 30 * time.Second

// Server is the agent's local HTTP endpoint. NewServer binds it to a loopback
// TCP address. Serve runs it until SIGINT/SIGTERM. Each request passes a
// concurrency limit, an optional SSRF token check and a GET-only rule before
// it is dispatched to the ping or get-secret handler.
type Server struct {
	// listener is the TCP listener opened by NewServer and served by Serve.
	listener              *net.TCPListener
	// cacheStore and kmsClient are supplied by NewServer's caller. They are not
	// read in this part of the file (presumably used by the get-secret handler
	// defined elsewhere — confirm).
	cacheStore            cache.Cache
	kmsClient             *kms.KeyManagementService
	// loggerWrapper receives the server's informational and error log lines.
	loggerWrapper         logger.Wrapper
	// ssrfToken is the expected token value returned by getToken. It stays
	// empty when DisableSSRFToken is set.
	ssrfToken             string
	// ssrfHeaders lists the request header names that validateToken inspects
	// for ssrfToken.
	ssrfHeaders           []string
	// pathPrefix routes every request path starting with it to handleGetSecret.
	pathPrefix            string
	// maxConn bounds the number of requests handled concurrently. /ping gets a
	// few extra slots on top of it (see validateMaxConn).
	maxConn               int
	// responseType comes from configuration. It is not read in this part of
	// the file.
	responseType          model.ResponseType
	// IgnoreTransientErrors comes from configuration. It is not read in this
	// part of the file (presumably consulted by the secret handler — confirm).
	IgnoreTransientErrors bool
	// DisableSSRFToken turns off the SSRF token comparison in validateToken.
	DisableSSRFToken      bool
	// connCount counts in-flight requests. It is only modified with
	// sync/atomic (validateMaxConn increments, handleRequest decrements).
	connCount             int32
}

// NewServer creates a Server listening on 127.0.0.1 at the configured HTTP port.
//
// When SSRF protection is enabled (DisableSSRFToken is false), the token is
// loaded with getToken from the configured environment variables. This happens
// before any socket is opened, so a missing or invalid token cannot leave a
// bound, never-closed listener behind.
//
// Every serverConfig pointer field read here is dereferenced without a nil
// check. Errors from token loading, address resolution and listening are
// returned wrapped with context, together with a nil *Server.
func NewServer(serverConfig conf.ServerConfig, cacheStore cache.Cache,
	kmsClient *kms.KeyManagementService, lw logger.Wrapper) (*Server, error) {
	var token string
	if !*serverConfig.DisableSSRFToken {
		var err error
		if token, err = getToken(*serverConfig.SSRFEnvVariables); err != nil {
			return nil, fmt.Errorf("loading SSRF token: %w", err)
		}
	}

	// Only accept connections from the local host.
	addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", *serverConfig.HttpPort))
	if err != nil {
		return nil, fmt.Errorf("resolving listen address: %w", err)
	}
	listener, err := net.ListenTCP("tcp", addr)
	if err != nil {
		return nil, fmt.Errorf("listening on %s: %w", addr, err)
	}

	return &Server{
		listener:              listener,
		cacheStore:            cacheStore,
		kmsClient:             kmsClient,
		loggerWrapper:         lw,
		ssrfToken:             token,
		ssrfHeaders:           *serverConfig.SSRFHeaders,
		pathPrefix:            *serverConfig.PathPrefix,
		maxConn:               *serverConfig.MaxConn,
		responseType:          model.ResponseType(*serverConfig.ResponseType),
		IgnoreTransientErrors: *serverConfig.IgnoreTransientErrors,
		DisableSSRFToken:      *serverConfig.DisableSSRFToken,
	}, nil
}

// Serve serves HTTP on the listener created by NewServer and blocks until
// SIGINT or SIGTERM arrives. It then shuts the server down gracefully, giving
// in-flight requests up to 5 seconds to complete.
//
// If the server stops for any reason other than that shutdown, the error is
// logged and the process exits with status 1.
func (s *Server) Serve() {
	mux := http.NewServeMux()
	mux.HandleFunc("/", s.handleRequest)
	server := &http.Server{
		Handler:      mux,
		ReadTimeout:  defaultRequestTimeout,
		WriteTimeout: defaultRequestTimeout,
	}
	go func() {
		s.loggerWrapper.Info("server listening on: %s", s.listener.Addr().String())
		// Serve always returns a non-nil error. ErrServerClosed is the expected
		// result of the Shutdown call below and must not terminate the process;
		// doing so would cut graceful shutdown short and exit with status 1.
		if err := server.Serve(s.listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
			s.loggerWrapper.Error("server listening err: %v", err)
			os.Exit(1)
		}
	}()

	// Wait for an interrupt/termination signal, then shut down gracefully.
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
	<-signals
	// Restore default handling so a second signal terminates immediately
	// instead of being silently dropped.
	signal.Stop(signals)

	s.loggerWrapper.Info("shutting down the server")
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := server.Shutdown(ctx); err != nil {
		s.loggerWrapper.Error("server shutdown Failed :%v", err)
	}
	s.loggerWrapper.Info("server exiting")
}

// handleRequest is the single handler registered on the mux. It applies, in
// order, the concurrency limit (429), the SSRF token check (403) and the
// GET-only rule (405). It then dispatches PingPath to handlePing, and
// QueryStylePath or anything under pathPrefix to handleGetSecret. Every other
// path gets a 404.
//
// A panic anywhere in the handler is recovered, logged with its stack trace,
// and answered with a 500 status.
func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
	defer func() {
		if p := recover(); p != nil {
			w.WriteHeader(http.StatusInternalServerError)
			s.loggerWrapper.Error("HandlePanic:%v\tTraceStack:%s", p, string(debug.Stack()))
		}
	}()

	if limitErr := s.validateMaxConn(r); limitErr != nil {
		http.Error(w, limitErr.Error(), http.StatusTooManyRequests)
		return
	}
	// validateMaxConn has counted this request in connCount; release that slot
	// when the handler returns.
	defer atomic.AddInt32(&s.connCount, -1)

	if tokenErr := s.validateToken(r); tokenErr != nil {
		http.Error(w, tokenErr.Error(), http.StatusForbidden)
		return
	}
	if methodErr := s.validateMethod(r); methodErr != nil {
		http.Error(w, methodErr.Error(), http.StatusMethodNotAllowed)
		return
	}

	switch path := r.URL.Path; {
	case path == PingPath:
		s.handlePing(w, r)
	case path == QueryStylePath, strings.HasPrefix(path, s.pathPrefix):
		s.handleGetSecret(w, r)
	default:
		http.NotFound(w, r)
	}
}
// validateMaxConn reserves a slot in the in-flight request counter (connCount).
// It returns an error when the server is already at capacity.
//
// Regular requests are admitted while at most maxConn are in flight. /ping
// gets three extra slots so health checks still answer under load.
//
// On success the slot stays reserved and the caller must release it with
// atomic.AddInt32(&s.connCount, -1). On failure nothing stays reserved: the
// speculative increment is rolled back, so rejected requests cannot
// permanently inflate the counter and lock the server out.
func (s *Server) validateMaxConn(r *http.Request) error {
	limit := s.maxConn
	// Must match PingPath.
	if r.URL.Path == "/ping" {
		limit += 3
	}

	if count := atomic.AddInt32(&s.connCount, 1); count > int32(limit) {
		// The caller only decrements after a successful check, so undo our
		// increment here.
		atomic.AddInt32(&s.connCount, -1)
		return errors.New("connection limit exceeded")
	}
	return nil
}

// validateToken applies the SSRF protections to every path except /ping:
//   - A request carrying an X-Forwarded-For header (set by proxies) is
//     rejected. This check runs even when DisableSSRFToken is set.
//   - Unless DisableSSRFToken is set, at least one header named in
//     ssrfHeaders must carry a non-empty value equal to ssrfToken.
//
// It returns nil when the request may proceed.
func (s *Server) validateToken(r *http.Request) error {
	// Health checks need no token. Must match PingPath.
	if r.URL.Path == "/ping" {
		return nil
	}

	if _, ok := r.Header["X-Forwarded-For"]; ok {
		return errors.New("forwarded")
	}

	if s.DisableSSRFToken {
		return nil
	}

	for _, header := range s.ssrfHeaders {
		// Header.Get returns "" for a missing header. Never treat that as a
		// match, even if ssrfToken were somehow empty.
		if token := r.Header.Get(header); token != "" && token == s.ssrfToken {
			return nil
		}
	}

	return errors.New("bad token")
}

// validateMethod accepts only GET requests. Any other HTTP method yields an
// error, which handleRequest turns into a 405.
func (s *Server) validateMethod(r *http.Request) error {
	switch r.Method {
	case http.MethodGet:
		return nil
	default:
		return errors.New("http method just allowed get")
	}
}

// getToken resolves the SSRF token from the first environment variable in envs
// that is set to a non-empty value.
//
// A value of the form "file://<path>" names a file whose contents, with
// surrounding whitespace trimmed, are the token. Any other value is the token
// itself.
//
// An error is returned in three cases:
//   - none of envs holds a non-empty value;
//   - the referenced file cannot be read;
//   - the file contains only whitespace.
//
// An empty token must never be handed back as a valid credential.
func getToken(envs []string) (string, error) {
	var found string
	for _, envName := range envs {
		// Skip unset and empty variables so a later candidate can still
		// supply the token.
		if val, ok := os.LookupEnv(envName); ok && val != "" {
			found = val
			break
		}
	}

	if found == "" {
		return "", errors.New("environment variable not present, you must set one valid SSRFEnvVariable")
	}

	if !strings.HasPrefix(found, "file://") {
		return found, nil
	}

	path := strings.TrimPrefix(found, "file://")
	content, err := ioutil.ReadFile(path)
	if err != nil {
		return "", fmt.Errorf("reading SSRF token file %q: %w", path, err)
	}

	token := strings.TrimSpace(string(content))
	if token == "" {
		return "", fmt.Errorf("SSRF token file %q is empty", path)
	}
	return token, nil
}
